import foolbox as fb
import tensorflow as tf
import tensorflow_compression
from tensorflow.keras import layers, models
from tensorflow.keras.applications import ResNet50
import numpy as np
import argparse

# Upper bound on validation images evaluated per run: main() keeps
# NUM_IMAGES // batch_size batches of the dataset.
NUM_IMAGES = 50_000
# L-inf PGD attack (10 steps) shared by run_attack().
attack = fb.attacks.LinfPGD(steps=10)

def compute_accuracy(model, dataset):
    """Return the top-1 accuracy (in percent) of ``model`` over ``dataset``.

    This is a clean (non-adversarial) evaluation; no attack is run.

    Args:
        model: callable mapping a batch of images to per-class scores; the
            predicted class is the argmax over axis 1.
        dataset: iterable of ``(images, labels)`` batches with integer class
            ids. Labels may be eager TensorFlow tensors or NumPy arrays.

    Returns:
        Accuracy as a percentage in ``[0, 100]``. The value is also printed.

    Raises:
        ValueError: if ``dataset`` yields no images.
    """
    correct_predictions = 0
    total_images = 0
    for images, labels in dataset:
        predicted_labels = np.argmax(model(images), axis=1)
        # np.asarray accepts eager tf tensors as well as NumPy arrays, unlike .numpy().
        correct_predictions += int(np.sum(predicted_labels == np.asarray(labels)))
        total_images += images.shape[0]

    if total_images == 0:
        # Previously a bare ZeroDivisionError; make the cause explicit.
        raise ValueError("dataset yielded no images; cannot compute accuracy")
    accuracy = correct_predictions / total_images * 100
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

def run_attack(model_attack, model_defense, dataset, epsilons=(4, 8, 16), bounds=(-257.0, 257.0)):
    """Measure the robust accuracy of ``model_defense`` under L-inf PGD crafted on ``model_attack``.

    Adversarial images are generated with the module-level foolbox ``attack``,
    differentiating through ``model_attack``. They are then classified by
    ``model_defense``. Passing the same pipeline for both gives a white-box
    evaluation; passing different models measures transfer.

    Args:
        model_attack: Keras model the attack is computed against.
        model_defense: Keras model whose predictions on the adversarial images
            are scored.
        dataset: iterable of ``(images, labels)`` batches with integer class ids.
        epsilons: L-inf perturbation budgets, in the same units as the images.
            The defaults match the original fixed budgets (4, 8, 16).
        bounds: input range given to ``fb.TensorFlowModel``.

    Returns:
        List of robust accuracies in percent, one per entry of ``epsilons``,
        in the same order.

    Raises:
        ValueError: if ``dataset`` yields no images.
    """
    epsilons = list(epsilons)
    # NOTE(review): main() feeds raw pixels in [0, 255], but these default bounds
    # let PGD move pixels outside that range. Kept for reproducibility of earlier
    # results -- confirm whether (0.0, 255.0) was intended.
    fmodel = fb.TensorFlowModel(model_attack, bounds=bounds)
    correct = np.zeros(len(epsilons), dtype=np.int64)
    total_images = 0
    for images, labels in dataset:
        # foolbox returns (raw, clipped, success) with one entry per epsilon; only
        # the clipped adversarials (restricted to each epsilon ball) are scored.
        _, clipped_advs, _ = attack(fmodel, images, labels, epsilons=epsilons)
        true_labels = np.asarray(labels)
        for i, adv in enumerate(clipped_advs):
            predicted = np.argmax(model_defense(adv), axis=1)
            correct[i] += np.sum(predicted == true_labels)
        total_images += images.shape[0]

        running = ", ".join(
            f"eps={eps}: {count / total_images * 100:.2f}%" for eps, count in zip(epsilons, correct)
        )
        print(f"Processed {total_images} images, accuracy: {running}")

    if total_images == 0:
        raise ValueError("dataset yielded no images; cannot compute robust accuracy")
    robust_accuracy = [float(count) / total_images * 100 for count in correct]
    print(f"Robust Accuracy: {robust_accuracy}")
    return robust_accuracy

def _build_classifier(resnet50):
    """Wrap ``resnet50`` in a model that takes raw RGB images of shape (224, 224, 3)."""
    # Same channel flip (RGB -> BGR) and per-channel mean subtraction that
    # keras' resnet50.preprocess_input applies.
    mean = tf.constant([103.939, 116.779, 123.68])
    raw_image = layers.Input(shape=(224, 224, 3))
    normalized_image = raw_image[..., ::-1] - mean
    return models.Model(inputs=raw_image, outputs=resnet50(normalized_image))


def _build_mric_pipeline(compression_model, beta_value, classifier):
    """Build a model: resize to 256 -> ``compression_model`` -> resize to 224 -> ``classifier``."""
    input_tensor = layers.Input(shape=(224, 224, 3))
    beta_tensor = tf.zeros(1) + beta_value
    # NOTE(review): training=True and the [0][2] index are kept from the original;
    # presumably [0][2] selects the reconstructed image -- confirm against the saved model.
    x = compression_model([tf.image.resize(input_tensor, (256, 256)), beta_tensor], training=True)[0][2]
    output_tensor = classifier(tf.image.resize(x, (224, 224)))
    return models.Model(inputs=input_tensor, outputs=output_tensor)


def run_exp(beta, attack_beta, dataset, model_path, through=True):
    """Run clean and adversarial evaluations of the MRIC + ResNet50 pipeline.

    Args:
        beta: beta used by the defended (evaluated) MRIC pipeline.
        attack_beta: beta used by the MRIC copy the white-box attack runs through.
        dataset: iterable of ``(images, labels)`` batches of raw 224x224 RGB images.
        model_path: sub-directory of ``data/`` holding the saved MRIC model.
        through: unused; kept so existing callers keep working.

    Returns:
        ``(wb, bb, clean_accuracy, standard_accuracy)``:

        - ``wb``: robust accuracies of the defended pipeline under an attack
          crafted through the ``attack_beta`` pipeline.
        - ``bb``: robust accuracies under an attack crafted on plain ResNet50.
        - ``clean_accuracy``: clean accuracy of the defended pipeline.
        - ``standard_accuracy``: clean accuracy of plain ResNet50.
    """
    resnet50 = ResNet50(weights='imagenet')
    # Two separately loaded copies so each holds its own beta value.
    model = tf.keras.models.load_model(f'data/{model_path}')
    model.beta.assign(beta)
    model_a = tf.keras.models.load_model(f'data/{model_path}')
    model_a.beta.assign(attack_beta)

    classifier = _build_classifier(resnet50)
    pipeline_model = _build_mric_pipeline(model, beta, classifier)
    pipeline_model_a = _build_mric_pipeline(model_a, attack_beta, classifier)

    clean_accuracy = compute_accuracy(pipeline_model, dataset)
    standard_accuracy = compute_accuracy(classifier, dataset)
    wb = run_attack(pipeline_model_a, pipeline_model, dataset)
    bb = run_attack(classifier, pipeline_model, dataset)

    return wb, bb, clean_accuracy, standard_accuracy


def main():
    """Parse CLI arguments, run one MRIC experiment, and append its results to a CSV file."""
    from pathlib import Path

    argparser = argparse.ArgumentParser(description='Run MRIC experiment')
    argparser.add_argument('--model_path', type=str, default='MRIC', help='Model to use for the experiment')
    argparser.add_argument('--batch_size', type=int, default=1, help='Batch size for the dataset')
    # Required: run_exp() uses beta directly (model.beta.assign, tf.zeros(1) + beta),
    # so a missing value would otherwise fail deep inside model construction.
    argparser.add_argument('--beta', type=float, required=True, help='Beta value for the defense')
    # -1 is a sentinel meaning "attack with the defense beta".
    argparser.add_argument('--beta_attack', type=float, default=-1, help='Beta value for the attack')
    args = argparser.parse_args()
    if args.batch_size < 1:
        argparser.error('--batch_size must be a positive integer')

    BATCH_SIZE = args.batch_size
    beta = args.beta
    beta_attack = beta if args.beta_attack == -1 else args.beta_attack
    model_path = args.model_path

    # Load the validation set: one sub-directory per class, labels inferred as integer ids.
    dataset = tf.keras.preprocessing.image_dataset_from_directory(
        'data/val',
        labels='inferred',
        label_mode='int',         # Integer labels
        image_size=(224, 224),     # match your model input
        batch_size=BATCH_SIZE,
        shuffle=True,
        seed=123
    )
    # Evaluate at most NUM_IMAGES images (NUM_IMAGES // BATCH_SIZE batches).
    # NOTE(review): run_exp iterates this dataset four times. If tf.data reshuffles
    # on every pass and take() keeps fewer than all images, each evaluation may
    # see a different subset -- confirm, or fix the order if that matters.
    dataset = dataset.take(NUM_IMAGES // BATCH_SIZE)

    print(f"Running experiment with beta: {beta}")
    wb, bb, clean_accuracy, standard_accuracy = run_exp(beta, beta_attack, dataset, model_path)

    results_path = Path('results') / f'{model_path}results.csv'
    results_path.parent.mkdir(parents=True, exist_ok=True)
    # Convert to built-in floats so the lists render as plain numbers
    # (NumPy >= 2 would otherwise write e.g. "np.float64(12.3)").
    wb = [float(v) for v in wb]
    bb = [float(v) for v in bb]
    with open(results_path, 'a', encoding='utf-8') as f:
        f.write(f" defense beta: {beta}, attack beta: {beta_attack},{standard_accuracy},{clean_accuracy},{wb},{bb}\n")
    print(f"Results saved to {results_path}")

    def fmt_percentages(values):
        # wb/bb hold one value per attack epsilon, so they cannot be formatted with ':.2f' directly.
        return ", ".join(f"{v:.2f}%" for v in values)

    print(f"Beta: {beta}, clean accuracy without defense: {standard_accuracy:.2f}%, "
          f"with defense: {clean_accuracy:.2f}%")
    print(f"Robust accuracy of the defended model per epsilon -- "
          f"white-box: [{fmt_percentages(wb)}], black-box (transfer from ResNet50): [{fmt_percentages(bb)}]")
    

if __name__ == "__main__":
    # Run the experiment only when executed as a script, not when imported.
    main()